|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# |
| 3 | +# MobilityData 2025 |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import logging |
| 20 | +import os |
| 21 | +from typing import Dict, Optional |
| 22 | + |
| 23 | +import pandas as pd |
| 24 | +from sqlalchemy.exc import IntegrityError |
| 25 | +from sqlalchemy.orm import Session |
| 26 | + |
| 27 | +from shared.database.database import with_db_session |
| 28 | +from shared.database_gen.sqlacodegen_models import Feed, Redirectingid |
| 29 | + |
| 30 | +logger = logging.getLogger(__name__) |
| 31 | + |
| 32 | +TDG_REDIRECT_DATA_LINK = ( |
| 33 | + "https://raw.githubusercontent.com/MobilityData/mobility-feed-api/" |
| 34 | + "refs/heads/main/functions-data/tdg_feed_redirect/redirect_mdb_to_tdg.csv" |
| 35 | +) |
| 36 | + |
| 37 | +DEFAULT_COMMIT_BATCH_SIZE = 100 |
| 38 | + |
| 39 | + |
| 40 | +def _update_feed_redirect( |
| 41 | + db_session: Session, mdb_stable_id: str, tdg_stable_id: str |
| 42 | +) -> Dict[str, int]: |
| 43 | + """ |
| 44 | + Ensure there is a Redirectingid from MDB feed to TDG feed. |
| 45 | +
|
| 46 | + Returns a dict with counters: |
| 47 | + { |
| 48 | + "redirects_created": 0|1, |
| 49 | + "redirects_existing": 0|1, |
| 50 | + "missing_mdb_feeds": 0|1, |
| 51 | + "missing_tdg_feeds": 0|1, |
| 52 | + } |
| 53 | + """ |
| 54 | + counters = { |
| 55 | + "redirects_created": 0, |
| 56 | + "redirects_existing": 0, |
| 57 | + "missing_mdb_feeds": 0, |
| 58 | + "missing_tdg_feeds": 0, |
| 59 | + } |
| 60 | + |
| 61 | + mdb_feed: Feed | None = ( |
| 62 | + db_session.query(Feed).filter(Feed.stable_id == mdb_stable_id).one_or_none() |
| 63 | + ) |
| 64 | + if not mdb_feed: |
| 65 | + logger.warning( |
| 66 | + "MDB feed not found for stable_id=%s, skipping redirect", mdb_stable_id |
| 67 | + ) |
| 68 | + counters["missing_mdb_feeds"] = 1 |
| 69 | + return counters |
| 70 | + |
| 71 | + tdg_feed: Feed | None = ( |
| 72 | + db_session.query(Feed).filter(Feed.stable_id == tdg_stable_id).one_or_none() |
| 73 | + ) |
| 74 | + if not tdg_feed: |
| 75 | + logger.warning( |
| 76 | + "TDG feed not found for stable_id=%s, skipping redirect", tdg_stable_id |
| 77 | + ) |
| 78 | + counters["missing_tdg_feeds"] = 1 |
| 79 | + return counters |
| 80 | + |
| 81 | + # Both feeds exist: ensure redirect exists (source MDB → target TDG). |
| 82 | + redirect = ( |
| 83 | + db_session.query(Redirectingid) |
| 84 | + .filter( |
| 85 | + Redirectingid.target_id == tdg_feed.id, |
| 86 | + Redirectingid.source_id == mdb_feed.id, |
| 87 | + ) |
| 88 | + .one_or_none() |
| 89 | + ) |
| 90 | + |
| 91 | + if redirect: |
| 92 | + logger.info( |
| 93 | + "Redirect already exists: source=%s → target=%s", |
| 94 | + mdb_stable_id, |
| 95 | + tdg_stable_id, |
| 96 | + ) |
| 97 | + counters["redirects_existing"] = 1 |
| 98 | + return counters |
| 99 | + |
| 100 | + logger.info( |
| 101 | + "Creating redirect: source=%s → target=%s", |
| 102 | + mdb_stable_id, |
| 103 | + tdg_stable_id, |
| 104 | + ) |
| 105 | + redirect = Redirectingid( |
| 106 | + target_id=tdg_feed.id, |
| 107 | + source_id=mdb_feed.id, |
| 108 | + redirect_comment="Redirecting post TDG import", |
| 109 | + ) |
| 110 | + mdb_feed.status = "deprecated" |
| 111 | + db_session.add(redirect) |
| 112 | + counters["redirects_created"] = 1 |
| 113 | + return counters |
| 114 | + |
| 115 | + |
| 116 | +def commit_changes(db_session: Session, created_since_commit: int) -> None: |
| 117 | + """ |
| 118 | + Commit DB changes for redirects. |
| 119 | +
|
| 120 | + Mirrors the TDG import pattern: commit, rollback on IntegrityError. |
| 121 | + """ |
| 122 | + try: |
| 123 | + logger.info( |
| 124 | + "Committing DB changes after creating %d redirect(s)", created_since_commit |
| 125 | + ) |
| 126 | + db_session.commit() |
| 127 | + except IntegrityError: |
| 128 | + db_session.rollback() |
| 129 | + logger.exception( |
| 130 | + "Commit failed with IntegrityError; rolled back TDG redirects batch" |
| 131 | + ) |
| 132 | + |
| 133 | + |
| 134 | +@with_db_session |
| 135 | +def _update_tdg_redirects(db_session: Session, dry_run: bool = True) -> dict: |
| 136 | + """ |
| 137 | + Orchestrate TDG redirect updates: |
| 138 | + - Load redirect CSV |
| 139 | + - For each row, ensure redirect from MDB → TDG |
| 140 | + - Support dry_run and batch commits (COMMIT_BATCH_SIZE) |
| 141 | + """ |
| 142 | + logger.info("Starting TDG redirects update dry_run=%s", dry_run) |
| 143 | + |
| 144 | + try: |
| 145 | + df = pd.read_csv(TDG_REDIRECT_DATA_LINK) |
| 146 | + except Exception as e: |
| 147 | + logger.exception( |
| 148 | + "Failed to load TDG redirect CSV from %s", TDG_REDIRECT_DATA_LINK |
| 149 | + ) |
| 150 | + return { |
| 151 | + "message": "Failed to load TDG redirect CSV.", |
| 152 | + "error": str(e), |
| 153 | + "params": {"dry_run": dry_run}, |
| 154 | + "rows_processed": 0, |
| 155 | + "redirects_created": 0, |
| 156 | + "redirects_existing": 0, |
| 157 | + "missing_mdb_feeds": 0, |
| 158 | + "missing_tdg_feeds": 0, |
| 159 | + } |
| 160 | + |
| 161 | + commit_batch_size = int( |
| 162 | + os.getenv("COMMIT_BATCH_SIZE", str(DEFAULT_COMMIT_BATCH_SIZE)) |
| 163 | + ) |
| 164 | + logger.info("Commit batch size (env COMMIT_BATCH_SIZE)=%s", commit_batch_size) |
| 165 | + |
| 166 | + rows_processed = 0 |
| 167 | + redirects_created = 0 |
| 168 | + redirects_existing = 0 |
| 169 | + missing_mdb_feeds = 0 |
| 170 | + missing_tdg_feeds = 0 |
| 171 | + |
| 172 | + created_since_commit = 0 |
| 173 | + |
| 174 | + for idx, row in df.iterrows(): |
| 175 | + mdb_stable_id = row.get("MDB ID") |
| 176 | + tdg_ids_raw = row.get("TDG ID") |
| 177 | + |
| 178 | + if not isinstance(mdb_stable_id, str) or not isinstance(tdg_ids_raw, str): |
| 179 | + logger.warning( |
| 180 | + "Skipping row index=%s: invalid MDB/TDG IDs row=%s", |
| 181 | + idx, |
| 182 | + row.to_dict(), |
| 183 | + ) |
| 184 | + continue |
| 185 | + |
| 186 | + tdg_stable_ids = [ |
| 187 | + f"tdg-{stable_id.strip()}" |
| 188 | + for stable_id in tdg_ids_raw.split(",") |
| 189 | + if str(stable_id).strip() |
| 190 | + ] |
| 191 | + |
| 192 | + for tdg_stable_id in tdg_stable_ids: |
| 193 | + rows_processed += 1 |
| 194 | + logger.debug( |
| 195 | + "Processing redirect row: MDB=%s TDG=%s", |
| 196 | + mdb_stable_id, |
| 197 | + tdg_stable_id, |
| 198 | + ) |
| 199 | + |
| 200 | + counters = _update_feed_redirect( |
| 201 | + db_session=db_session, |
| 202 | + mdb_stable_id=mdb_stable_id, |
| 203 | + tdg_stable_id=tdg_stable_id, |
| 204 | + ) |
| 205 | + |
| 206 | + redirects_created += counters["redirects_created"] |
| 207 | + redirects_existing += counters["redirects_existing"] |
| 208 | + missing_mdb_feeds += counters["missing_mdb_feeds"] |
| 209 | + missing_tdg_feeds += counters["missing_tdg_feeds"] |
| 210 | + |
| 211 | + created_since_commit += counters["redirects_created"] |
| 212 | + |
| 213 | + if not dry_run and created_since_commit >= commit_batch_size: |
| 214 | + commit_changes(db_session, created_since_commit) |
| 215 | + created_since_commit = 0 |
| 216 | + |
| 217 | + if not dry_run and created_since_commit > 0: |
| 218 | + commit_changes(db_session, created_since_commit) |
| 219 | + |
| 220 | + message = ( |
| 221 | + "Dry run: no DB writes performed." |
| 222 | + if dry_run |
| 223 | + else "TDG redirects update executed successfully." |
| 224 | + ) |
| 225 | + summary = { |
| 226 | + "message": message, |
| 227 | + "rows_processed": rows_processed, |
| 228 | + "redirects_created": redirects_created, |
| 229 | + "redirects_existing": redirects_existing, |
| 230 | + "missing_mdb_feeds": missing_mdb_feeds, |
| 231 | + "missing_tdg_feeds": missing_tdg_feeds, |
| 232 | + "params": {"dry_run": dry_run}, |
| 233 | + } |
| 234 | + logger.info("TDG redirects update summary: %s", summary) |
| 235 | + return summary |
| 236 | + |
| 237 | + |
| 238 | +def update_tdg_redirects_handler(payload: Optional[dict] = None) -> dict: |
| 239 | + """ |
| 240 | + Cloud Function-style entrypoint. |
| 241 | +
|
| 242 | + Payload: {"dry_run": bool} (default True) |
| 243 | + """ |
| 244 | + payload = payload or {} |
| 245 | + logger.info("update_tdg_redirects_handler called with payload=%s", payload) |
| 246 | + |
| 247 | + dry_run_raw = payload.get("dry_run", True) |
| 248 | + dry_run = ( |
| 249 | + dry_run_raw |
| 250 | + if isinstance(dry_run_raw, bool) |
| 251 | + else str(dry_run_raw).lower() == "true" |
| 252 | + ) |
| 253 | + logger.info("Parsed dry_run=%s (raw=%s)", dry_run, dry_run_raw) |
| 254 | + |
| 255 | + result = _update_tdg_redirects(dry_run=dry_run) |
| 256 | + logger.info( |
| 257 | + "update_tdg_redirects_handler summary: %s", |
| 258 | + { |
| 259 | + k: result.get(k) |
| 260 | + for k in ( |
| 261 | + "message", |
| 262 | + "rows_processed", |
| 263 | + "redirects_created", |
| 264 | + "redirects_existing", |
| 265 | + "missing_mdb_feeds", |
| 266 | + "missing_tdg_feeds", |
| 267 | + ) |
| 268 | + }, |
| 269 | + ) |
| 270 | + return result |
0 commit comments