121 changes: 93 additions & 28 deletions api/src/shared/common/db_utils.py
@@ -1,10 +1,11 @@
import logging
import os
from typing import Iterator, List, Dict, Optional

from geoalchemy2 import WKTElement
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload, Session, contains_eager, load_only
from sqlalchemy.orm import joinedload, Session, contains_eager, load_only, selectinload
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.strategy_options import _AbstractLoad
from sqlalchemy import func
@@ -48,7 +49,7 @@ def get_gtfs_feeds_query(
is_official: bool | None = None,
published_only: bool = True,
include_options_for_joinedload: bool = True,
) -> Query[any]:
) -> Query:
"""Get the DB query to use to retrieve the GTFS feeds.."""
gtfs_feed_filter = GtfsFeedFilter(
stable_id=stable_id,
@@ -159,36 +160,70 @@ def get_all_gtfs_feeds(

:return: The GTFS feeds in an iterator.
"""
batch_size = int(os.getenv("BATCH_SIZE", "500"))
batch_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
batch_size = int(os.getenv("BATCH_SIZE", "50"))

# We fetch in small batches and stream results to avoid loading the whole table in memory.
# stream_results=True lets SQLAlchemy iterate rows without buffering them all at once.
# We also clear the session cache between batches (see expunge_all() below) to prevent
# memory from growing indefinitely when many ORM objects are loaded.
batch_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).execution_options(stream_results=True)
if published_only:
batch_query = batch_query.filter(Gtfsfeed.operational_status == "published")

for batch in batched(batch_query, batch_size):
stable_ids = (f.stable_id for f in batch)
processed = 0

for batch_num, batch in enumerate(batched(batch_query, batch_size), start=1):
start_index = processed + 1
end_index = processed + len(batch)
logging.info("Processing feeds %d - %d", start_index, end_index)

# Convert to a list intentionally: we want to "materialize" IDs now to make any cost
# visible here (and keep the logic simple). This also avoids subtle lazy-evaluation
# effects that can hide where time/memory is really spent.
stable_ids = [f.stable_id for f in batch]
if not stable_ids:
processed += len(batch)
continue

if w_extracted_locations_only:
feed_query = apply_most_common_location_filter(db_session.query(Gtfsfeed), db_session)
yield from (
feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)).options(
joinedload(Gtfsfeed.latest_dataset)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
*get_joinedload_options(include_extracted_location_entities=True),
)
inner_q = feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)).options(
# See note above: selectinload is chosen for collections to keep memory and row
# counts under control when streaming.
selectinload(Gtfsfeed.latest_dataset)
.selectinload(Gtfsdataset.validation_reports)
.selectinload(Validationreport.features),
selectinload(Gtfsfeed.bounding_box_dataset),
*get_selectinload_options(include_extracted_location_entities=True),
)
else:
yield from (
inner_q = (
db_session.query(Gtfsfeed)
.outerjoin(Gtfsfeed.gtfsdatasets)
.filter(Gtfsfeed.stable_id.in_(stable_ids))
.options(
joinedload(Gtfsfeed.latest_dataset)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
*get_joinedload_options(include_extracted_location_entities=False),
selectinload(Gtfsfeed.latest_dataset)
.selectinload(Gtfsdataset.validation_reports)
.selectinload(Validationreport.features),
selectinload(Gtfsfeed.bounding_box_dataset),
*get_selectinload_options(include_extracted_location_entities=False),
)
)

# Iterate and stream rows out; the options above ensure related data is preloaded in
# a few small queries per batch, rather than one giant join.
for item in inner_q.execution_options(stream_results=True):
yield item

# Clear the Session identity map so objects from this batch can be GC'd. Without this,
# the Session will keep references and memory usage will grow with each batch.
try:
db_session.expunge_all()
except Exception:
logging.getLogger("get_all_gtfs_feeds").exception("Failed to expunge session after batch %d", batch_num)

processed += len(batch)
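
Taken together, the loop above streams the feed table in id order, re-selects each batch with eager loaders, yields the rows, and then purges the Session. A minimal standalone sketch of that pattern, assuming a generic SQLAlchemy Session and a mapped Feed class with a stable_id column (batched, stream_feeds, and Feed.locations are illustrative stand-ins, not this repository's helpers):

from itertools import islice

from sqlalchemy.orm import Session, selectinload


def batched(iterable, n):
    # Stand-in for itertools.batched (Python 3.12+): yield lists of up to n items.
    it = iter(iterable)
    while chunk := list(islice(it, n)):
        yield chunk


def stream_feeds(session: Session, Feed, batch_size: int = 50):
    # stream_results=True requests a server-side cursor, so rows are fetched
    # incrementally instead of being buffered in memory all at once.
    base = (
        session.query(Feed)
        .order_by(Feed.stable_id)
        .execution_options(stream_results=True)
    )
    for batch in batched(base, batch_size):
        ids = [feed.stable_id for feed in batch]
        # Re-select the batch with collection loaders; selectinload issues one
        # small "IN (...)" query per relationship instead of one giant join.
        yield from (
            session.query(Feed)
            .filter(Feed.stable_id.in_(ids))
            .options(selectinload(Feed.locations))
        )
        # Drop the identity map so this batch's objects can be garbage-collected;
        # otherwise the Session keeps a reference to every row ever loaded.
        session.expunge_all()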


def get_gtfs_rt_feeds_query(
limit: int | None,
@@ -278,7 +313,10 @@ def get_all_gtfs_rt_feeds(
:return: The GTFS realtime feeds in an iterator.
"""
batched_query = (
db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
db_session.query(Gtfsrealtimefeed.stable_id)
.order_by(Gtfsrealtimefeed.stable_id)
.yield_per(batch_size)
.execution_options(stream_results=True)
)
if published_only:
batched_query = batched_query.filter(Gtfsrealtimefeed.operational_status == "published")
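
A side note on the combination above: yield_per(batch_size) limits how many ORM rows SQLAlchemy buffers per fetch, and on recent SQLAlchemy versions it enables streaming on its own, so the explicit stream_results=True is defensive redundancy. A minimal sketch of the same id-streaming idiom, assuming any mapped Model with a stable_id column (illustrative, not this repository's code):

from sqlalchemy.orm import Session


def iter_stable_ids(session: Session, Model, batch_size: int = 50):
    # Stream one column in id order; at most ~batch_size rows are held at a time.
    query = (
        session.query(Model.stable_id)
        .order_by(Model.stable_id)
        .yield_per(batch_size)
        .execution_options(stream_results=True)
    )
    for (stable_id,) in query:
        yield stable_id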
@@ -290,8 +328,8 @@ def get_all_gtfs_rt_feeds(
yield from (
feed_query.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
selectinload(Gtfsrealtimefeed.entitytypes),
selectinload(Gtfsrealtimefeed.gtfs_feeds),
*get_joinedload_options(include_extracted_location_entities=True),
)
.order_by(Gtfsfeed.stable_id)
@@ -301,9 +339,9 @@
db_session.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*get_joinedload_options(include_extracted_location_entities=False),
selectinload(Gtfsrealtimefeed.entitytypes),
selectinload(Gtfsrealtimefeed.gtfs_feeds),
*get_selectinload_options(include_extracted_location_entities=False),
)
)

Expand All @@ -319,10 +357,10 @@ def apply_bounding_filtering(
if not bounding_latitudes or not bounding_longitudes or not bounding_filter_method:
return query

if (
len(bounding_latitudes_tokens := bounding_latitudes.split(",")) != 2
or len(bounding_longitudes_tokens := bounding_longitudes.split(",")) != 2
):
# Parse tokens explicitly to satisfy static analyzers and keep error messages clear.
bounding_latitudes_tokens = bounding_latitudes.split(",")
bounding_longitudes_tokens = bounding_longitudes.split(",")
if len(bounding_latitudes_tokens) != 2 or len(bounding_longitudes_tokens) != 2:
raise_internal_http_validation_error(
invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes)
)
@@ -385,6 +423,33 @@ def get_joinedload_options(include_extracted_location_entities: bool = False) ->
]


def get_selectinload_options(include_extracted_location_entities: bool = False) -> List[_AbstractLoad]:
"""
Returns common selectinload options for feed queries.
:param include_extracted_location_entities: Whether to include extracted location entities.

:return: A list of selectinload options.
"""
# NOTE: For collections we prefer selectinload to avoid row explosion and high memory usage
# during streaming. When callers explicitly join some paths (e.g., most common locations),
# we use contains_eager on that specific path to tell SQLAlchemy the data came from a JOIN.
loaders = []
if include_extracted_location_entities:
loaders.append(contains_eager(Feed.feedosmlocationgroups).joinedload(Feedosmlocationgroup.group))

# collections -> selectinload; scalar relationships can remain joinedload
loaders.extend(
[
selectinload(Feed.locations),
selectinload(Feed.externalids),
selectinload(Feed.feedrelatedlinks),
selectinload(Feed.redirectingids).selectinload(Redirectingid.target),
selectinload(Feed.officialstatushistories),
]
)
return loaders
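
The helper above encodes a general SQLAlchemy trade-off: joinedload pulls a collection in through a LEFT OUTER JOIN, so a parent with N children comes back as N joined rows with the parent columns duplicated on each, while selectinload emits one extra "SELECT ... WHERE parent_id IN (...)" per relationship, keeping row counts additive rather than multiplicative. A minimal sketch of the difference, assuming toy Parent/Child mapped classes with a one-to-many Parent.children relationship (illustrative, not from this codebase):

from sqlalchemy.orm import Session, joinedload, selectinload


def load_with_joinedload(session: Session, Parent):
    # One statement: SELECT parent.*, child.* FROM parent
    # LEFT OUTER JOIN child ON child.parent_id = parent.id
    # -> a parent with 50 children occupies 50 rows in the result set.
    return session.query(Parent).options(joinedload(Parent.children)).all()


def load_with_selectinload(session: Session, Parent):
    # Two statements: SELECT parent.* FROM parent, then
    # SELECT child.* FROM child WHERE child.parent_id IN (...)
    # -> row counts stay linear, which matters when several collections
    #    are eagerly loaded on the same query.
    return session.query(Parent).options(selectinload(Parent.children)).all()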


def get_gbfs_feeds_query(
db_session: Session,
stable_id: Optional[str] = None,
2 changes: 1 addition & 1 deletion functions-python/export_csv/function_config.json
@@ -2,7 +2,7 @@
"name": "export-csv",
"description": "Export the DB feed data as a csv file",
"entry_point": "export_and_upload_csv",
"timeout": 600,
"timeout": 3600,
"memory": "2Gi",
"trigger_http": true,
"include_folders": ["helpers", "dataset_service"],
2 changes: 1 addition & 1 deletion functions-python/export_csv/requirements.txt
@@ -27,4 +27,4 @@ python-dotenv==1.0.0

# Other dependencies
natsort

psutil
30 changes: 21 additions & 9 deletions functions-python/export_csv/src/main.py
@@ -15,7 +15,6 @@
#
import argparse
import csv
import logging
import os
import re
from typing import Dict, Iterator, Optional
@@ -28,6 +27,7 @@
from google.cloud import storage
from geoalchemy2.shape import to_shape

from shared.helpers.runtime_metrics import track_metrics
from shared.database.database import with_db_session
from shared.helpers.logger import init_logger
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed, Feed
@@ -39,6 +39,8 @@

from shared.database_gen.sqlacodegen_models import Geopolygon

import logging

load_dotenv()
csv_default_file_path = "./output.csv"
init_logger()
@@ -124,6 +126,7 @@ def export_and_upload_csv(_):
return "Export successful"


@track_metrics(metrics=("time", "memory", "cpu"))
def export_csv(csv_file_path: str):
"""
Write feed data to a local CSV file.
@@ -318,7 +321,7 @@ def get_feed_csv_data(
) -> Dict:
"""
This function takes a generic feed and returns a dictionary with the data to be written to the CSV file.
Any specific data (for GTFS or GTFS_RT has to be added after this call.
Any specific data (for GTFS or GTFS_RT) has to be added after this call.
"""

redirect_ids = []
@@ -409,15 +412,24 @@ def get_gtfs_rt_feed_csv_data(
static_references = ""
first_feed_reference = None
if feed.gtfs_feeds:
valid_feed_references = [
feed_reference.stable_id.strip()
for feed_reference in feed.gtfs_feeds
if feed_reference and feed_reference.stable_id
# Prefer active feeds first using a stable sort so original relative order
# within active and inactive groups is preserved.
def _is_active(fr):
try:
return getattr(fr, "status", None) == "active"
except Exception:
return False

# Filter to valid references, then stable sort by active flag (True > False)
valid_refs = [
fr for fr in feed.gtfs_feeds if fr and getattr(fr, "stable_id", None)
]
sorted_refs = sorted(valid_refs, key=_is_active, reverse=True)

valid_feed_references = [fr.stable_id.strip() for fr in sorted_refs]
static_references = "|".join(valid_feed_references)
# If there is more than one GTFS feeds associated with this RT feed (why?)
# We will arbitrarily use the first one in the list for the bounding box.
first_feed_reference = feed.gtfs_feeds[0] if feed.gtfs_feeds else None
# First reference (after sort) will be an active one if any are present
first_feed_reference = sorted_refs[0] if sorted_refs else None
data["static_reference"] = static_references

# For the RT feed, we use the bounding box of the associated GTFS feed, if any.
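
main.py now wraps export_csv with track_metrics from shared.helpers.runtime_metrics, which is also why psutil was added to requirements.txt; the helper itself is not part of this diff. A hypothetical sketch of what such a decorator could look like, built only on documented psutil and stdlib calls (the real implementation may differ):

import functools
import logging
import time

import psutil


def track_metrics(metrics=("time", "memory", "cpu")):
    # Hypothetical decorator: log wall time, RSS delta, and CPU time of one call.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            proc = psutil.Process()
            wall0 = time.monotonic()
            rss0 = proc.memory_info().rss
            cpu0 = proc.cpu_times()
            try:
                return func(*args, **kwargs)
            finally:
                if "time" in metrics:
                    logging.info("%s wall time: %.2fs", func.__name__, time.monotonic() - wall0)
                if "memory" in metrics:
                    delta_mib = (proc.memory_info().rss - rss0) / (1024 * 1024)
                    logging.info("%s RSS delta: %+.1f MiB", func.__name__, delta_mib)
                if "cpu" in metrics:
                    cpu1 = proc.cpu_times()
                    cpu_s = (cpu1.user - cpu0.user) + (cpu1.system - cpu0.system)
                    logging.info("%s CPU time: %.2fs", func.__name__, cpu_s)
        return wrapper
    return decorator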
63 changes: 43 additions & 20 deletions functions-python/export_csv/tests/conftest.py
@@ -50,6 +50,30 @@ def populate_database(db_session):
fake = Faker()

feeds = []

# Insert the deprecated feeds before the active feeds so they appear first in
# GtfsRealtimeFeed.gtfs_feeds (the RT feed references). This lets the tests verify that
# active feeds are sorted to the front. Admittedly, it's a bit weak, but it works for now.
for i in range(2):
feed = Gtfsfeed(
id=fake.uuid4(),
data_type="gtfs",
feed_name=f"deprecated-gtfs-{i} Some fake name",
note=f"deprecated-gtfs-{i} Some fake note",
producer_url=f"https://deprecated-gtfs-{i}_some_fake_producer_url",
authentication_type="0" if (i == 0) else "1",
authentication_info_url=None,
api_key_parameter_name=None,
license_url=f"https://gtfs-{i}_some_fake_license_url",
stable_id=f"deprecated-gtfs-{i}",
status="deprecated",
feed_contact_email=f"deprecated-gtfs-{i}[email protected]",
provider=f"deprecated-gtfs-{i} Some fake company",
operational_status="published",
official=True,
)
db_session.add(feed)

# We create 3 feeds. The first one is active. The third one is inactive and redirected to the first one.
# The second one is active but not redirected.
# First fill the generic parameters
@@ -97,25 +121,6 @@ def populate_database(db_session):
for feed in feeds:
db_session.add(feed)
db_session.flush()
for i in range(2):
feed = Gtfsfeed(
id=fake.uuid4(),
data_type="gtfs",
feed_name=f"gtfs-deprecated-{i} Some fake name",
note=f"gtfs-deprecated-{i} Some fake note",
producer_url=f"https://gtfs-deprecated-{i}_some_fake_producer_url",
authentication_type="0" if (i == 0) else "1",
authentication_info_url=None,
api_key_parameter_name=None,
license_url=f"https://gtfs-{i}_some_fake_license_url",
stable_id=f"gtfs-deprecated-{i}",
status="deprecated",
feed_contact_email=f"gtfs-deprecated-{i}[email protected]",
provider=f"gtfs-deprecated-{i} Some fake company",
operational_status="published",
official=True,
)
db_session.add(feed)

location_entity = Location(id="CA-quebec-montreal")

@@ -273,10 +278,28 @@ def populate_database(db_session):
entitytypes=[vp_entitytype, tu_entitytype] if i == 0 else [vp_entitytype],
operational_status="published",
official=True,
gtfs_feeds=[active_gtfs_feeds[0]] if i == 0 else [],
# Do not attach GTFS feeds at creation; we'll set them in a controlled order below
# gtfs_feeds=[],
)
gtfs_rt_feeds.append(feed)

db_session.add_all(gtfs_rt_feeds)

# --- Attach both a deprecated GTFS feed and an active GTFS feed to the first RT feed
try:
deprecated_feeds = (
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.status == "deprecated")
.order_by(Gtfsfeed.stable_id)
.all()
)
if deprecated_feeds:
gtfs_rt_feeds[0].gtfs_feeds = [deprecated_feeds[0], active_gtfs_feeds[0]]
db_session.flush()
except Exception:
# Best effort in test setup; if it fails the rest of the tests will surface the issue.
pass

# Add redirecting IDs (from main branch logic)
gtfs_rt_feeds[1].redirectingids = [
Redirectingid(
Expand Down
6 changes: 3 additions & 3 deletions functions-python/export_csv/tests/test_export_csv_main.py
@@ -24,12 +24,12 @@
# the data is correct.
expected_csv = """
id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,is_official,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment
deprecated-gtfs-0,gtfs,,,,,deprecated-gtfs-0 Some fake company,True,deprecated-gtfs-0 Some fake name,deprecated-gtfs-0 Some fake note,[email protected],,https://deprecated-gtfs-0_some_fake_producer_url,0,,,,https://gtfs-0_some_fake_license_url,,,,,,deprecated,,,
deprecated-gtfs-1,gtfs,,,,,deprecated-gtfs-1 Some fake company,True,deprecated-gtfs-1 Some fake name,deprecated-gtfs-1 Some fake note,[email protected],,https://deprecated-gtfs-1_some_fake_producer_url,1,,,,https://gtfs-1_some_fake_license_url,,,,,,deprecated,,,
gtfs-0,gtfs,,CA,Quebec,Laval,gtfs-0 Some fake company,True,gtfs-0 Some fake name,gtfs-0 Some fake note,[email protected],,https://gtfs-0_some_fake_producer_url,0,,,https://url_prefix/gtfs-0/latest.zip,https://gtfs-0_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,Route Colors|Shapes,,
gtfs-1,gtfs,,CA,Quebec,Montreal,gtfs-1 Some fake company,True,gtfs-1 Some fake name,gtfs-1 Some fake note,[email protected],,https://gtfs-1_some_fake_producer_url,0,,,https://url_prefix/gtfs-1/latest.zip,https://gtfs-1_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,Route Colors|Shapes,,
gtfs-2,gtfs,,,,,gtfs-2 Some fake company,True,gtfs-2 Some fake name,gtfs-2 Some fake note,[email protected],,https://gtfs-2_some_fake_producer_url,0,,,,https://gtfs-2_some_fake_license_url,,,,,,inactive,,gtfs-0,Some redirect comment
gtfs-deprecated-0,gtfs,,,,,gtfs-deprecated-0 Some fake company,True,gtfs-deprecated-0 Some fake name,gtfs-deprecated-0 Some fake note,[email protected],,https://gtfs-deprecated-0_some_fake_producer_url,0,,,,https://gtfs-0_some_fake_license_url,,,,,,deprecated,,,
gtfs-deprecated-1,gtfs,,,,,gtfs-deprecated-1 Some fake company,True,gtfs-deprecated-1 Some fake name,gtfs-deprecated-1 Some fake note,[email protected],,https://gtfs-deprecated-1_some_fake_producer_url,1,,,,https://gtfs-1_some_fake_license_url,,,,,,deprecated,,,
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,True,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,[email protected],gtfs-0,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,,,
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,True,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,[email protected],gtfs-0|deprecated-gtfs-0,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,,,
gtfs-rt-1,gtfs_rt,vp,,,,gtfs-rt-1 Some fake company,True,gtfs-rt-1 Some fake name,gtfs-rt-1 Some fake note,[email protected],,https://gtfs-rt-1_some_fake_producer_url,1,https://gtfs-rt-1_some_fake_authentication_info_url,gtfs-rt-1_fake_api_key_parameter_name,,https://gtfs-rt-1_some_fake_license_url,,,,,,inactive,,gtfs-rt-0|gtfs-rt-2,comment 1|comment 2
gtfs-rt-2,gtfs_rt,vp,,,,gtfs-rt-2 Some fake company,True,gtfs-rt-2 Some fake name,gtfs-rt-2 Some fake note,[email protected],,https://gtfs-rt-2_some_fake_producer_url,2,https://gtfs-rt-2_some_fake_authentication_info_url,gtfs-rt-2_fake_api_key_parameter_name,,https://gtfs-rt-2_some_fake_license_url,,,,,,active,,,
""" # noqa