diff --git a/functions-python/tasks_executor/requirements.txt b/functions-python/tasks_executor/requirements.txt index ecdc13525..563cd57f6 100644 --- a/functions-python/tasks_executor/requirements.txt +++ b/functions-python/tasks_executor/requirements.txt @@ -28,3 +28,6 @@ google-cloud-storage # Configuration python-dotenv==1.0.0 pycountry + +# Other utilities +pandas diff --git a/functions-python/tasks_executor/src/main.py b/functions-python/tasks_executor/src/main.py index 622fab5ab..a1495b2f0 100644 --- a/functions-python/tasks_executor/src/main.py +++ b/functions-python/tasks_executor/src/main.py @@ -20,6 +20,7 @@ import functions_framework from shared.helpers.logger import init_logger +from tasks.data_import.transitfeeds.sync_transitfeeds import sync_transitfeeds_handler from tasks.dataset_files.rebuild_missing_dataset_files import ( rebuild_missing_dataset_files_handler, ) @@ -38,7 +39,7 @@ from tasks.geojson.update_geojson_files_precision import ( update_geojson_files_precision_handler, ) -from tasks.data_import.import_jbda_feeds import import_jbda_handler +from tasks.data_import.jbda.import_jbda_feeds import import_jbda_handler from tasks.licenses.populate_license_rules import ( populate_license_rules_handler, @@ -98,6 +99,10 @@ "description": "Populates licenses and license-rules in the database from a predefined JSON source.", "handler": populate_licenses_handler, }, + "sync_transitfeeds_data": { + "description": "Syncs data from TransitFeeds to the database.", + "handler": sync_transitfeeds_handler, + }, } diff --git a/functions-python/tasks_executor/src/tasks/data_import/data_import_utils.py b/functions-python/tasks_executor/src/tasks/data_import/data_import_utils.py new file mode 100644 index 000000000..b9ac2e6d5 --- /dev/null +++ b/functions-python/tasks_executor/src/tasks/data_import/data_import_utils.py @@ -0,0 +1,99 @@ +import logging +import uuid +from datetime import datetime +from typing import Tuple, Type, TypeVar + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from shared.database_gen.sqlacodegen_models import ( + Feed, + Officialstatushistory, + Entitytype, +) + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound="Feed") + + +def _get_or_create_entity_type(session: Session, entity_type_name: str) -> Entitytype: + """Get or create an Entitytype by name.""" + logger.debug("Looking up Entitytype name=%s", entity_type_name) + et = session.scalar(select(Entitytype).where(Entitytype.name == entity_type_name)) + if et: + logger.debug("Found existing Entitytype name=%s", entity_type_name) + return et + et = Entitytype(name=entity_type_name) + session.add(et) + session.flush() + logger.info("Created Entitytype name=%s", entity_type_name) + return et + + +def get_feed( + session: Session, + stable_id: str, + model: Type[T] = Feed, +) -> T | None: + """Get a Feed by stable_id.""" + logger.debug("Lookup feed stable_id=%s", stable_id) + feed = session.scalar(select(model).where(model.stable_id == stable_id)) + if feed: + logger.debug("Found existing feed stable_id=%s id=%s", stable_id, feed.id) + else: + logger.debug("No Feed found with stable_id=%s", stable_id) + return feed + + +def _get_or_create_feed( + session: Session, + model: Type[T], + stable_id: str, + data_type: str, + is_official: bool = True, + official_notes: str = "Imported from JBDA as official feed.", + reviewer_email: str = "emma@mobilitydata.org", +) -> Tuple[T, bool]: + """Generic helper to get or create a Feed subclass (Gtfsfeed, Gtfsrealtimefeed) by stable_id.""" + 
logger.debug( + "Lookup feed model=%s stable_id=%s", + getattr(model, "__name__", str(model)), + stable_id, + ) + feed = session.scalar(select(model).where(model.stable_id == stable_id)) + if feed: + logger.info( + "Found existing %s stable_id=%s id=%s", + getattr(model, "__name__", str(model)), + stable_id, + feed.id, + ) + return feed, False + + new_id = str(uuid.uuid4()) + feed = model( + id=new_id, + data_type=data_type, + stable_id=stable_id, + official=is_official, + official_updated_at=datetime.now(), + ) + if is_official: + feed.officialstatushistories = [ + Officialstatushistory( + is_official=True, + reviewer_email=reviewer_email, + timestamp=datetime.now(), + notes=official_notes, + ) + ] + session.add(feed) + session.flush() + logger.info( + "Created %s stable_id=%s id=%s data_type=%s", + getattr(model, "__name__", str(model)), + stable_id, + new_id, + data_type, + ) + return feed, True diff --git a/functions-python/tasks_executor/src/tasks/data_import/import_jbda_feeds.py b/functions-python/tasks_executor/src/tasks/data_import/jbda/import_jbda_feeds.py similarity index 91% rename from functions-python/tasks_executor/src/tasks/data_import/import_jbda_feeds.py rename to functions-python/tasks_executor/src/tasks/data_import/jbda/import_jbda_feeds.py index 09cdae7d0..8373b1c09 100644 --- a/functions-python/tasks_executor/src/tasks/data_import/import_jbda_feeds.py +++ b/functions-python/tasks_executor/src/tasks/data_import/jbda/import_jbda_feeds.py @@ -20,13 +20,13 @@ import os import uuid from datetime import datetime -from typing import Optional, Tuple, Dict, Any, List, Final, Type, TypeVar +from typing import Optional, Tuple, Dict, Any, List, Final, TypeVar -import requests import pycountry +import requests from sqlalchemy import select, and_ -from sqlalchemy.orm import Session from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from shared.common.locations_utils import create_or_get_location from shared.database.database import with_db_session @@ -34,13 +34,14 @@ Feed, Gtfsfeed, Gtfsrealtimefeed, - Entitytype, Feedrelatedlink, Externalid, - Officialstatushistory, ) - from shared.helpers.pub_sub import trigger_dataset_download +from tasks.data_import.data_import_utils import ( + _get_or_create_entity_type, + _get_or_create_feed, +) T = TypeVar("T", bound="Feed") @@ -99,20 +100,6 @@ def import_jbda_handler(payload: dict | None = None) -> dict: return result -def _get_or_create_entity_type(session: Session, entity_type_name: str) -> Entitytype: - """Get or create an Entitytype by name.""" - logger.debug("Looking up Entitytype name=%s", entity_type_name) - et = session.scalar(select(Entitytype).where(Entitytype.name == entity_type_name)) - if et: - logger.debug("Found existing Entitytype name=%s", entity_type_name) - return et - et = Entitytype(name=entity_type_name) - session.add(et) - session.flush() - logger.info("Created Entitytype name=%s", entity_type_name) - return et - - def get_gtfs_file_url( detail_body: Dict[str, Any], rid: str = "current" ) -> Optional[str]: @@ -144,53 +131,6 @@ def get_gtfs_file_url( return None -def _get_or_create_feed( - session: Session, model: Type[T], stable_id: str, data_type: str -) -> Tuple[T, bool]: - """Generic helper to get or create a Feed subclass (Gtfsfeed, Gtfsrealtimefeed) by stable_id.""" - logger.debug( - "Lookup feed model=%s stable_id=%s", - getattr(model, "__name__", str(model)), - stable_id, - ) - feed = session.scalar(select(model).where(model.stable_id == stable_id)) - if feed: - logger.info( - "Found 
existing %s stable_id=%s id=%s", - getattr(model, "__name__", str(model)), - stable_id, - feed.id, - ) - return feed, False - - new_id = str(uuid.uuid4()) - feed = model( - id=new_id, - data_type=data_type, - stable_id=stable_id, - official=True, - official_updated_at=datetime.now(), - ) - feed.officialstatushistories = [ - Officialstatushistory( - is_official=True, - reviewer_email="emma@mobilitydata.org", - timestamp=datetime.now(), - notes="Imported from JBDA as official feed.", - ) - ] - session.add(feed) - session.flush() - logger.info( - "Created %s stable_id=%s id=%s data_type=%s", - getattr(model, "__name__", str(model)), - stable_id, - new_id, - data_type, - ) - return feed, True - - def _update_common_feed_fields( feed: Feed, list_item: dict, detail: dict, producer_url: str ) -> None: diff --git a/functions-python/tasks_executor/src/tasks/data_import/transitfeeds/sync_transitfeeds.py b/functions-python/tasks_executor/src/tasks/data_import/transitfeeds/sync_transitfeeds.py new file mode 100644 index 000000000..3e685df62 --- /dev/null +++ b/functions-python/tasks_executor/src/tasks/data_import/transitfeeds/sync_transitfeeds.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python3 +# +# MobilityData 2025 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import logging +import traceback +import uuid +from typing import Optional, Type, Callable, Dict + +import pandas as pd +from sqlalchemy.orm import Session + +from shared.common.locations_utils import create_or_get_location +from shared.database.database import with_db_session +from shared.database_gen.sqlacodegen_models import ( + Gtfsfeed, + Externalid, + Redirectingid, + Gtfsrealtimefeed, + Gtfsdataset, +) +from tasks.data_import.data_import_utils import ( + _get_or_create_feed, + _get_or_create_entity_type, + get_feed, +) + +logger = logging.getLogger(__name__) + + +def _safe_split(val: Optional[str], sep: str = " | ") -> list[str]: + """Split a string on `sep`, guarding against NaN/None and trimming parts.""" + if pd.isna(val) or val is None: + return [] + return [part.strip() for part in str(val).split(sep) if part and str(part).strip()] + + +def _process_feeds( + db_session: Session, + csv_url: str, + model_cls: Type, # Gtfsfeed or Gtfsrealtimefeed + feed_kind: str, # 'gtfs' or 'gtfs_rt' + dry_run: bool, + on_is_new: Optional[Callable[[Session, object, pd.Series], None]] = None, +) -> Dict[str, int | str]: + """Generic CSV → Feed loader for TransitFeeds imports.""" + try: + logger.info( + "Loading %s feeds from CSV: %s (dry_run=%s)", + feed_kind.upper(), + csv_url, + dry_run, + ) + df = pd.read_csv(csv_url) + logger.debug("CSV loaded: %d rows found for %s.", len(df), feed_kind.upper()) + + total_processed = total_created = total_updated = 0 + + for idx, row in df.iterrows(): + feed_stable_id = row["Mobility Database Feed ID"] + logger.debug( + "[%s][%d/%d] Processing feed_stable_id=%s", + feed_kind.upper(), + idx + 1, + len(df), + feed_stable_id, + ) + + feed, is_new = _get_or_create_feed( + db_session, 
model_cls, feed_stable_id, feed_kind, is_official=False + ) + # All TransitFeeds imports are marked deprecated + feed.status = "deprecated" + logger.info( + "[%s] %s feed %s", + feed_kind.upper(), + "Creating" if is_new else "Updating", + feed_stable_id, + ) + + # Init-on-create (shared fields) + if is_new: + feed.name = row["Feed Name"] + feed.externalids = [ + Externalid( + source="transitfeeds", associated_id=row["External Feed ID"] + ) + ] + feed.provider = row["Provider"] + feed.producer_url = row["Producer URL"] + logger.debug( + "[%s] Initialized new feed fields for %s (name=%s, provider=%s)", + feed_kind.upper(), + feed_stable_id, + feed.name, + feed.provider, + ) + + if on_is_new is not None: + try: + on_is_new(db_session, feed, row) + logger.debug( + "[%s] on_is_new callback executed for %s", + feed_kind.upper(), + feed_stable_id, + ) + except Exception as e: + logger.exception( + "[%s] on_is_new callback failed for %s: %s", + feed_kind.upper(), + feed_stable_id, + e, + ) + raise + + # Redirects (shared) + redirect_ids = _safe_split(row.get("Redirects")) + if redirect_ids: + feed.redirectingids.clear() + logger.info( + "[%s] Set %d redirecting ids for %s", + feed_kind.upper(), + len(redirect_ids), + feed_stable_id, + ) + logger.debug( + "[%s] %s redirect ids parsed: %s", + feed_kind.upper(), + feed_stable_id, + redirect_ids, + ) + try: + for target_id in redirect_ids: + target_feed = get_feed(db_session, target_id) + if not target_feed: + logger.warning( + "Redirect target feed not found for feed %s: %s", + feed_stable_id, + target_id, + ) + continue + feed.redirectingids.append( + Redirectingid( + target_id=target_feed.id, + redirect_comment="Deprecated historical feed from TransitFeeds", + ) + ) + except Exception as e: + logger.error( + "Redirect target feed not found for feed %s: %s", + feed_stable_id, + str(e), + ) + raise + + # Location (shared) + logger.debug( + "[%s] Resolving location for %s (Country=%s, Subdivision=%s, Municipality=%s)", + feed_kind.upper(), + feed_stable_id, + row["Country"], + row["Subdivision"], + row["Municipality"], + ) + location = create_or_get_location( + db_session, + country=row["Country"], + state_province=row["Subdivision"] + if not pd.isna(row["Subdivision"]) + else None, + city_name=row["Municipality"] + if not pd.isna(row["Municipality"]) + else None, + ) + if not getattr(feed, "locations", []) and location: + feed.locations = [location] + logger.debug( + "[%s] Assigned first location to %s", + feed_kind.upper(), + feed_stable_id, + ) + + total_processed += 1 + total_created += int(is_new) + total_updated += int(not is_new) + + if not dry_run: + logger.info( + "[%s] Committing DB transaction for %d processed feeds", + feed_kind.upper(), + total_processed, + ) + db_session.commit() + else: + logger.info( + "[%s] Dry-run enabled; no DB commit performed", feed_kind.upper() + ) + + logger.info( + "[%s] Done. 
processed=%d created=%d updated=%d", + feed_kind.upper(), + total_processed, + total_created, + total_updated, + ) + + return { + "message": f"Processed {total_processed} {feed_kind.upper()} feeds from TransitFeeds.", + "total_processed": total_processed, + "total_created": total_created, + "total_updated": total_updated, + } + + except Exception as error: + traceback.print_exc() + logger.exception("Error processing %s feeds: %s", feed_kind.upper(), error) + raise + + +def _process_transitfeeds_gtfs(db_session: Session, dry_run: bool) -> dict: + """Process TransitFeeds GTFS CSV.""" + logger.info("Starting GTFS feeds processing (dry_run=%s)", dry_run) + return _process_feeds( + db_session=db_session, + csv_url="https://raw.githubusercontent.com/MobilityData/mobility-feed-api/refs/heads/main/functions-data/" + "transitfeeds_data_import/gtfs_feeds.csv", + model_cls=Gtfsfeed, + feed_kind="gtfs", + dry_run=dry_run, + on_is_new=None, + ) + + +def _process_transitfeeds_gtfs_rt(db_session: Session, dry_run: bool) -> dict: + """Process TransitFeeds GTFS-RT CSV.""" + logger.info("Starting GTFS-RT feeds processing (dry_run=%s)", dry_run) + + def _rt_on_is_new(session: Session, feed, row: pd.Series) -> None: + entity_types = _safe_split(row.get("Entity Types")) + logger.debug( + "[GTFS_RT] %s entity types: %s", + row["Mobility Database Feed ID"], + entity_types, + ) + if entity_types: + feed.entitytypes = [ + _get_or_create_entity_type(session, et) for et in entity_types + ] + logger.info( + "[GTFS_RT] Set %d entity types for %s", + len(entity_types), + row["Mobility Database Feed ID"], + ) + + return _process_feeds( + db_session=db_session, + csv_url="https://raw.githubusercontent.com/MobilityData/mobility-feed-api/refs/heads/main/functions-data/" + "transitfeeds_data_import/gtfs_rt_feeds.csv", + model_cls=Gtfsrealtimefeed, + feed_kind="gtfs_rt", + dry_run=dry_run, + on_is_new=_rt_on_is_new, + ) + + +def _add_historical_datasets(db_session: Session, dry_run: bool) -> int: + """Create/attach historical datasets per feed (idempotent). 
Returns count added.""" + df = pd.read_csv( + "https://raw.githubusercontent.com/MobilityData/mobility-feed-api/refs/heads/main" + "/functions-data/transitfeeds_data_import/historical_datasets.csv" + ) + logger.debug("Historical datasets CSV loaded: %d rows", len(df)) + + total_added = 0 + grouped = df.groupby("Feed ID") + logger.debug("Grouped historical datasets by Feed ID: %d groups", len(grouped)) + + for _, grouped_df in grouped: + feed_stable_id = grouped_df["Feed ID"].iloc[0] + logger.debug( + "Processing historical datasets for feed_stable_id=%s (%d rows)", + feed_stable_id, + len(grouped_df), + ) + feed = get_feed(db_session, feed_stable_id, model=Gtfsfeed) + if not feed: + logger.warning( + "Feed with stable_id=%s not found; skipping historical datasets.", + feed_stable_id, + ) + continue + + # Newest first + grouped_df = grouped_df.sort_values( + by="Dataset ID", ascending=False + ).reset_index(drop=True) + + datasets: list[Gtfsdataset] = [] + latest_candidate_id: Optional[str] = None + latest_already_set = feed.latest_dataset_id is not None + + for i, (_, row) in enumerate(grouped_df.iterrows()): + tfs_dataset_id = row["Dataset ID"] + tfs_dataset_suffix = tfs_dataset_id.split("/")[-1] + mdb_dataset_stable_id = f"{feed_stable_id}-{tfs_dataset_suffix}" + + existing_dataset = ( + db_session.query(Gtfsdataset) + .filter(Gtfsdataset.stable_id == mdb_dataset_stable_id) + .first() + ) + + if existing_dataset: + logger.info( + "Historical dataset %s already exists; skipping creation.", + existing_dataset.stable_id, + ) + if (i == 0) and not latest_already_set and latest_candidate_id is None: + latest_candidate_id = existing_dataset.id + continue + + date_str = tfs_dataset_suffix.split("-")[0] + download_date = pd.to_datetime(date_str, format="%Y%m%d", errors="coerce") + if pd.isna(download_date): + try: + # Convert only if it's numeric + if date_str.isdigit(): + download_date = pd.to_datetime(int(date_str), unit="s") + else: + raise ValueError + except Exception: + logger.warning( + "Invalid date in Dataset ID %s; skipping.", tfs_dataset_id + ) + continue + + sdr_start = pd.to_datetime( + row["Service Date Range Start"], format="%Y%m%d", errors="coerce" + ) + sdr_end = pd.to_datetime( + row["Service Date Range End"], format="%Y%m%d", errors="coerce" + ) + + dataset_id = str(uuid.uuid4()) + ds = Gtfsdataset( + id=dataset_id, + stable_id=mdb_dataset_stable_id, + hosted_url=( + f"https://openmobilitydata-data.s3.us-west-1.amazonaws.com/public/feeds/" + f"{tfs_dataset_id}/gtfs.zip" + ), + downloaded_at=download_date, + service_date_range_start=None if pd.isna(sdr_start) else sdr_start, + service_date_range_end=None if pd.isna(sdr_end) else sdr_end, + feed_id=feed.id, + ) + datasets.append(ds) + logger.debug( + "Prepared new dataset %s (downloaded_at=%s) for feed %s", + ds.stable_id, + ds.downloaded_at, + feed_stable_id, + ) + + if (i == 0) and not latest_already_set and latest_candidate_id is None: + latest_candidate_id = dataset_id + + if datasets: + db_session.add_all(datasets) + logger.debug( + "Added %d new historical datasets for %s to the session.", + len(datasets), + feed_stable_id, + ) + else: + logger.debug("No new datasets to add for %s.", feed_stable_id) + + # Persist children before touching parent FK + db_session.flush() + + if not latest_already_set and latest_candidate_id: + feed.latest_dataset_id = latest_candidate_id + logger.info( + "Set latest_dataset_id for feed %s to %s", + feed_stable_id, + latest_candidate_id, + ) + db_session.flush() + + total_added += len(datasets) 
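+        # Note: `datasets` holds only the rows created in this pass; datasets that
+        # already existed were skipped above, so total_added counts new rows only.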
+ logger.info( + "Assigned %d historical datasets to feed %s (latest_dataset_id=%s)", + len(datasets), + feed_stable_id, + feed.latest_dataset_id, + ) + + if not dry_run: + db_session.commit() + logger.debug("Committed historical datasets for %s", feed_stable_id) + else: + logger.debug("Dry-run: skipped commit for %s", feed_stable_id) + + logger.info("Finished adding historical datasets. total_added=%d", total_added) + return total_added + + +@with_db_session +def _sync_transitfeeds(db_session: Session, dry_run: bool = True) -> dict: + """Run the TransitFeeds sync end-to-end.""" + logger.info("Starting TransitFeeds sync (dry_run=%s)", dry_run) + gtfs_feeds_processing_result = _process_transitfeeds_gtfs( + db_session, dry_run=dry_run + ) + gtfs_rt_processing_result = _process_transitfeeds_gtfs_rt( + db_session, dry_run=dry_run + ) + + datasets_added = _add_historical_datasets(db_session, dry_run=dry_run) + + total_processed = ( + gtfs_feeds_processing_result["total_processed"] + + gtfs_rt_processing_result["total_processed"] + ) + logger.info( + "TransitFeeds sync complete. total_processed=%d datasets_added=%d", + total_processed, + datasets_added, + ) + + return { + "message": ( + f"Sync TransitFeeds completed. " + f"Total processed feeds: {total_processed}. " + f"Datasets added: {datasets_added}." + ), + "total_processed": total_processed, + "datasets_added": datasets_added, + "details": { + "gtfs_feeds": gtfs_feeds_processing_result, + "gtfs_rt_feeds": gtfs_rt_processing_result, + }, + } + + +def sync_transitfeeds_handler(payload: dict | None = None) -> dict: + """ + Cloud Function entrypoint. + Payload: {"dry_run": bool} (default True) + """ + payload = payload or {} + logger.info("sync_transitfeeds_handler called with payload=%s", payload) + + dry_run_raw = payload.get("dry_run", True) + dry_run = ( + dry_run_raw + if isinstance(dry_run_raw, bool) + else str(dry_run_raw).lower() == "true" + ) + logger.info("Parsed dry_run=%s (raw=%s)", dry_run, dry_run_raw) + + try: + result = _sync_transitfeeds(dry_run=dry_run) + except Exception as e: + logger.exception("Error during TransitFeeds sync: %s", e) + return { + "message": f"Error during TransitFeeds sync: {str(e)}", + "total_processed": 0, + "error": str(e), + } + + logger.info( + "sync_transitfeeds_handler summary: %s", + {k: result.get(k) for k in ("message", "total_processed", "datasets_added")}, + ) + return result diff --git a/functions-python/tasks_executor/tests/tasks/data_import/test_jbda_import.py b/functions-python/tasks_executor/tests/tasks/data_import/test_jbda_import.py index f143fd3b0..0fccf04d1 100644 --- a/functions-python/tasks_executor/tests/tasks/data_import/test_jbda_import.py +++ b/functions-python/tasks_executor/tests/tasks/data_import/test_jbda_import.py @@ -13,7 +13,7 @@ Feedrelatedlink, ) -from tasks.data_import.import_jbda_feeds import ( +from tasks.data_import.jbda.import_jbda_feeds import ( import_jbda_handler, get_gtfs_file_url, ) @@ -188,7 +188,7 @@ def _head_side_effect(url, allow_redirects=True, timeout=15): return _FakeResponse(status=404) with patch( - "tasks.data_import.import_jbda_feeds.requests.head", + "tasks.data_import.jbda.import_jbda_feeds.requests.head", side_effect=_head_side_effect, ): self.assertEqual(get_gtfs_file_url(detail, rid="current"), url_current) @@ -221,15 +221,15 @@ def _head_side_effect(url, allow_redirects=True, timeout=15): # Patch requests.Session and head; replace old pubsub mocks with trigger_dataset_download mock_trigger = MagicMock() with patch( - 
"tasks.data_import.import_jbda_feeds.requests.Session", + "tasks.data_import.jbda.import_jbda_feeds.requests.Session", return_value=_FakeSessionOK(), ), patch( - "tasks.data_import.import_jbda_feeds.requests.head", + "tasks.data_import.jbda.import_jbda_feeds.requests.head", side_effect=_head_side_effect, ), patch( - "tasks.data_import.import_jbda_feeds.REQUEST_TIMEOUT_S", 0.01 + "tasks.data_import.jbda.import_jbda_feeds.REQUEST_TIMEOUT_S", 0.01 ), patch( - "tasks.data_import.import_jbda_feeds.trigger_dataset_download", + "tasks.data_import.jbda.import_jbda_feeds.trigger_dataset_download", mock_trigger, ), patch.dict( os.environ, {"COMMIT_BATCH_SIZE": "1"}, clear=False @@ -311,9 +311,9 @@ def _head_side_effect(url, allow_redirects=True, timeout=15): @with_db_session(db_url=default_db_url) def test_import_http_failure_graceful(self, db_session: Session): with patch( - "tasks.data_import.import_jbda_feeds.requests.Session", + "tasks.data_import.jbda.import_jbda_feeds.requests.Session", return_value=_FakeSessionError(), - ), patch("tasks.data_import.import_jbda_feeds.REQUEST_TIMEOUT_S", 0.01): + ), patch("tasks.data_import.jbda.import_jbda_feeds.REQUEST_TIMEOUT_S", 0.01): out = import_jbda_handler({"dry_run": True}) self.assertEqual(out["message"], "Failed to fetch JBDA feeds.") diff --git a/functions-python/tasks_executor/tests/tasks/data_import/test_transitfeeds_sync.py b/functions-python/tasks_executor/tests/tasks/data_import/test_transitfeeds_sync.py new file mode 100644 index 000000000..f64728f2a --- /dev/null +++ b/functions-python/tasks_executor/tests/tasks/data_import/test_transitfeeds_sync.py @@ -0,0 +1,202 @@ +import os +import unittest +from unittest.mock import patch + +import pandas as pd +from sqlalchemy.orm import Session + +from shared.database.database import with_db_session +from shared.database_gen.sqlacodegen_models import ( + Gtfsfeed, + Gtfsrealtimefeed, + Gtfsdataset, +) +from tasks.data_import.transitfeeds.sync_transitfeeds import ( + sync_transitfeeds_handler, +) +from test_shared.test_utils.database_utils import default_db_url + + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers to fabricate CSV DataFrames +# ───────────────────────────────────────────────────────────────────────────── + + +def _df_gtfs_feeds() -> pd.DataFrame: + # Minimal columns used by the loader + return pd.DataFrame( + [ + { + "Mobility Database Feed ID": "mdb-123", + "Feed Name": "Sample Feed", + "External Feed ID": "tf-777", + "Provider": "Provider A", + "Producer URL": "https://example.com/a.zip", + "Redirects": "", # keep empty to avoid FK lookups + "Country": "Canada", + "Subdivision": "QC", + "Municipality": "Laval", + } + ] + ) + + +def _df_gtfs_rt_feeds() -> pd.DataFrame: + return pd.DataFrame( + [ + { + "Mobility Database Feed ID": "mdb-rt-1", + "Feed Name": "Sample RT Feed", + "External Feed ID": "tf-rt-999", + "Provider": "Provider A", + "Producer URL": "https://rt.example.com/endpoint", + "Redirects": "", + "Country": "Canada", + "Subdivision": "QC", + "Municipality": "Laval", + # Entity types separated by ' | ' to hit _safe_split + "Entity Types": "tu | vp", + } + ] + ) + + +def _df_historical_datasets() -> pd.DataFrame: + # Two datasets for mdb-123; the "Dataset ID" suffix encodes YYYYMMDD-... + # Code sorts descending by 'Dataset ID', so 20250101 comes before 20241201. 
+ return pd.DataFrame( + [ + { + "Feed ID": "mdb-123", + "Dataset ID": "provider/feed/20250101-releaseA", + "Service Date Range Start": "20250101", + "Service Date Range End": "20250331", + }, + { + "Feed ID": "mdb-123", + "Dataset ID": "provider/feed/20241201-releaseB", + "Service Date Range Start": "20241201", + "Service Date Range End": "20250228", + }, + ] + ) + + +def _read_csv_side_effect(path: str, *args, **kwargs) -> pd.DataFrame: + fname = os.path.basename(path) + if fname == "gtfs_feeds.csv": + return _df_gtfs_feeds() + if fname == "gtfs_rt_feeds.csv": + return _df_gtfs_rt_feeds() + if fname == "historical_datasets.csv": + return _df_historical_datasets() + raise AssertionError(f"Unexpected CSV read: {path}") + + +# ───────────────────────────────────────────────────────────────────────────── +# Tests +# ───────────────────────────────────────────────────────────────────────────── + + +class TestTransitFeedsSync(unittest.TestCase): + @with_db_session(db_url=default_db_url) + def test_sync_creates_feeds_and_datasets(self, db_session: Session): + # Arrange CSVs + with patch( + "tasks.data_import.transitfeeds.sync_transitfeeds.pd.read_csv", + side_effect=_read_csv_side_effect, + ): + # Act + result = sync_transitfeeds_handler({"dry_run": False}) + + # Assert handler summary + self.assertIn("message", result) + self.assertIn("total_processed", result) + self.assertIn("datasets_added", result) + self.assertEqual(result["total_processed"], 2) # 1 GTFS + 1 RT row + self.assertEqual(result["datasets_added"], 2) + + # Verify GTFS feed was created + gtfs: Gtfsfeed | None = ( + db_session.query(Gtfsfeed).filter(Gtfsfeed.stable_id == "mdb-123").first() + ) + self.assertIsNotNone(gtfs) + self.assertEqual(gtfs.status, "deprecated") + # Externalids contains transitfeeds entry + self.assertTrue( + any( + e.source == "transitfeeds" and e.associated_id == "tf-777" + for e in (gtfs.externalids or []) + ) + ) + # Location assigned once + self.assertTrue(gtfs.locations) + self.assertEqual(getattr(gtfs.locations[0], "subdivision_name", None), "QC") + + # Historical datasets created (2) & latest_dataset_id points to newest + datasets = ( + db_session.query(Gtfsdataset).filter(Gtfsdataset.feed_id == gtfs.id).all() + ) + self.assertEqual(len(datasets), 2) + + newest_stable = "mdb-123-20250101-releaseA" + newest = next(ds for ds in datasets if ds.stable_id == newest_stable) + self.assertEqual(gtfs.latest_dataset_id, newest.id) + + # Verify RT feed creation and entity types + rt: Gtfsrealtimefeed | None = ( + db_session.query(Gtfsrealtimefeed) + .filter(Gtfsrealtimefeed.stable_id == "mdb-rt-1") + .first() + ) + self.assertIsNotNone(rt) + # entity_types relationship name is 'entity_types' in your code path + et_names = [getattr(et, "name", None) for et in (rt.entitytypes or [])] + self.assertCountEqual(et_names, ["tu", "vp"]) + + @with_db_session(db_url=default_db_url) + def test_sync_handles_missing_redirect_targets_and_empty_history( + self, db_session: Session + ): + # Build CSVs where GTFS has a bogus redirect; historical is empty + df_gtfs = _df_gtfs_feeds().copy() + df_gtfs.loc[0, "Redirects"] = "nonexistent-id" + df_rt = _df_gtfs_rt_feeds().copy() + df_hist = pd.DataFrame( + columns=[ + "Feed ID", + "Dataset ID", + "Service Date Range Start", + "Service Date Range End", + ] + ) + + def _side_effect(path: str, *args, **kwargs): + fname = os.path.basename(path) + if fname == "gtfs_feeds.csv": + return df_gtfs + if fname == "gtfs_rt_feeds.csv": + return df_rt + if fname == "historical_datasets.csv": + 
return df_hist + raise AssertionError(f"Unexpected CSV read: {path}") + + with patch( + "tasks.data_import.transitfeeds.sync_transitfeeds.pd.read_csv", + side_effect=_side_effect, + ): + out = sync_transitfeeds_handler({"dry_run": False}) + + # We processed 2 feeds; datasets_added is 0 + self.assertEqual(out["total_processed"], 2) + self.assertEqual(out["datasets_added"], 0) + + # GTFS feed exists even though redirect target was missing + gtfs = ( + db_session.query(Gtfsfeed).filter(Gtfsfeed.stable_id == "mdb-123").first() + ) + self.assertIsNotNone(gtfs) + + +if __name__ == "__main__": + unittest.main()
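For reviewers who want to exercise the new task locally before going through the executor, here is a minimal sketch. It relies only on what the diff above defines (`sync_transitfeeds_handler`, the `dry_run` payload key, and the summary keys the handler returns); running it outside the Cloud Functions runtime and having the database configured the way `with_db_session` expects are assumptions.

    # Hypothetical local smoke test; not part of the change itself.
    from tasks.data_import.transitfeeds.sync_transitfeeds import sync_transitfeeds_handler

    # dry_run defaults to True; string values like "true"/"false" are also accepted.
    result = sync_transitfeeds_handler({"dry_run": True})

    # On failure the handler returns an "error" key instead of raising,
    # so check for it before trusting the totals.
    if "error" in result:
        print("Sync failed:", result["error"])
    else:
        print(result["message"])
        print("feeds processed:", result["total_processed"])
        print("datasets added:", result["datasets_added"])

Once deployed, the same behaviour should be reachable through the `sync_transitfeeds_data` entry registered in `main.py`, presumably with the request payload forwarded to the handler unchanged.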