Skip to content

Commit 34ac9d3

Browse files
authored
fix: reduce memory usage in export_csv function. (#1469)
1 parent 65c9292 commit 34ac9d3

File tree

6 files changed

+162
-62
lines changed

6 files changed

+162
-62
lines changed

api/src/shared/common/db_utils.py

Lines changed: 93 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import logging
12
import os
23
from typing import Iterator, List, Dict, Optional
34

45
from geoalchemy2 import WKTElement
56
from sqlalchemy import or_
67
from sqlalchemy import select
7-
from sqlalchemy.orm import joinedload, Session, contains_eager, load_only
8+
from sqlalchemy.orm import joinedload, Session, contains_eager, load_only, selectinload
89
from sqlalchemy.orm.query import Query
910
from sqlalchemy.orm.strategy_options import _AbstractLoad
1011
from sqlalchemy import func
@@ -48,7 +49,7 @@ def get_gtfs_feeds_query(
4849
is_official: bool | None = None,
4950
published_only: bool = True,
5051
include_options_for_joinedload: bool = True,
51-
) -> Query[any]:
52+
) -> Query:
5253
"""Get the DB query to use to retrieve the GTFS feeds.."""
5354
gtfs_feed_filter = GtfsFeedFilter(
5455
stable_id=stable_id,
@@ -159,36 +160,70 @@ def get_all_gtfs_feeds(
159160
160161
:return: The GTFS feeds in an iterator.
161162
"""
162-
batch_size = int(os.getenv("BATCH_SIZE", "500"))
163-
batch_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
163+
batch_size = int(os.getenv("BATCH_SIZE", "50"))
164+
165+
# We fetch in small batches and stream results to avoid loading the whole table in memory.
166+
# stream_results=True lets SQLAlchemy iterate rows without buffering them all at once.
167+
# We also clear the session cache between batches (see expunge_all() below) to prevent
168+
# memory from growing indefinitely when many ORM objects are loaded.
169+
batch_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).execution_options(stream_results=True)
164170
if published_only:
165171
batch_query = batch_query.filter(Gtfsfeed.operational_status == "published")
166172

167-
for batch in batched(batch_query, batch_size):
168-
stable_ids = (f.stable_id for f in batch)
173+
processed = 0
174+
175+
for batch_num, batch in enumerate(batched(batch_query, batch_size), start=1):
176+
start_index = processed + 1
177+
end_index = processed + len(batch)
178+
logging.info("Processing feeds %d - %d", start_index, end_index)
179+
180+
# Convert to a list intentionally: we want to "materialize" IDs now to make any cost
181+
# visible here (and keep the logic simple). This also avoids subtle lazy-evaluation
182+
# effects that can hide where time/memory is really spent.
183+
stable_ids = [f.stable_id for f in batch]
184+
if not stable_ids:
185+
processed += len(batch)
186+
continue
187+
169188
if w_extracted_locations_only:
170189
feed_query = apply_most_common_location_filter(db_session.query(Gtfsfeed), db_session)
171-
yield from (
172-
feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)).options(
173-
joinedload(Gtfsfeed.latest_dataset)
174-
.joinedload(Gtfsdataset.validation_reports)
175-
.joinedload(Validationreport.features),
176-
*get_joinedload_options(include_extracted_location_entities=True),
177-
)
190+
inner_q = feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)).options(
191+
# See note above: selectinload is chosen for collections to keep memory and row
192+
# counts under control when streaming.
193+
selectinload(Gtfsfeed.latest_dataset)
194+
.selectinload(Gtfsdataset.validation_reports)
195+
.selectinload(Validationreport.features),
196+
selectinload(Gtfsfeed.bounding_box_dataset),
197+
*get_selectinload_options(include_extracted_location_entities=True),
178198
)
179199
else:
180-
yield from (
200+
inner_q = (
181201
db_session.query(Gtfsfeed)
182202
.outerjoin(Gtfsfeed.gtfsdatasets)
183203
.filter(Gtfsfeed.stable_id.in_(stable_ids))
184204
.options(
185-
joinedload(Gtfsfeed.latest_dataset)
186-
.joinedload(Gtfsdataset.validation_reports)
187-
.joinedload(Validationreport.features),
188-
*get_joinedload_options(include_extracted_location_entities=False),
205+
selectinload(Gtfsfeed.latest_dataset)
206+
.selectinload(Gtfsdataset.validation_reports)
207+
.selectinload(Validationreport.features),
208+
selectinload(Gtfsfeed.bounding_box_dataset),
209+
*get_selectinload_options(include_extracted_location_entities=False),
189210
)
190211
)
191212

213+
# Iterate and stream rows out; the options above ensure related data is preloaded in
214+
# a few small queries per batch, rather than one giant join.
215+
for item in inner_q.execution_options(stream_results=True):
216+
yield item
217+
218+
# Clear the Session identity map so objects from this batch can be GC'd. Without this,
219+
# the Session will keep references and memory usage will grow with each batch.
220+
try:
221+
db_session.expunge_all()
222+
except Exception:
223+
logging.getLogger("get_all_gtfs_feeds").exception("Failed to expunge session after batch %d", batch_num)
224+
225+
processed += len(batch)
226+
192227

193228
def get_gtfs_rt_feeds_query(
194229
limit: int | None,
@@ -278,7 +313,10 @@ def get_all_gtfs_rt_feeds(
278313
:return: The GTFS realtime feeds in an iterator.
279314
"""
280315
batched_query = (
281-
db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
316+
db_session.query(Gtfsrealtimefeed.stable_id)
317+
.order_by(Gtfsrealtimefeed.stable_id)
318+
.yield_per(batch_size)
319+
.execution_options(stream_results=True)
282320
)
283321
if published_only:
284322
batched_query = batched_query.filter(Gtfsrealtimefeed.operational_status == "published")
@@ -290,8 +328,8 @@ def get_all_gtfs_rt_feeds(
290328
yield from (
291329
feed_query.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
292330
.options(
293-
joinedload(Gtfsrealtimefeed.entitytypes),
294-
joinedload(Gtfsrealtimefeed.gtfs_feeds),
331+
selectinload(Gtfsrealtimefeed.entitytypes),
332+
selectinload(Gtfsrealtimefeed.gtfs_feeds),
295333
*get_joinedload_options(include_extracted_location_entities=True),
296334
)
297335
.order_by(Gtfsfeed.stable_id)
@@ -301,9 +339,9 @@ def get_all_gtfs_rt_feeds(
301339
db_session.query(Gtfsrealtimefeed)
302340
.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
303341
.options(
304-
joinedload(Gtfsrealtimefeed.entitytypes),
305-
joinedload(Gtfsrealtimefeed.gtfs_feeds),
306-
*get_joinedload_options(include_extracted_location_entities=False),
342+
selectinload(Gtfsrealtimefeed.entitytypes),
343+
selectinload(Gtfsrealtimefeed.gtfs_feeds),
344+
*get_selectinload_options(include_extracted_location_entities=False),
307345
)
308346
)
309347

@@ -319,10 +357,10 @@ def apply_bounding_filtering(
319357
if not bounding_latitudes or not bounding_longitudes or not bounding_filter_method:
320358
return query
321359

322-
if (
323-
len(bounding_latitudes_tokens := bounding_latitudes.split(",")) != 2
324-
or len(bounding_longitudes_tokens := bounding_longitudes.split(",")) != 2
325-
):
360+
# Parse tokens explicitly to satisfy static analyzers and keep error messages clear.
361+
bounding_latitudes_tokens = bounding_latitudes.split(",")
362+
bounding_longitudes_tokens = bounding_longitudes.split(",")
363+
if len(bounding_latitudes_tokens) != 2 or len(bounding_longitudes_tokens) != 2:
326364
raise_internal_http_validation_error(
327365
invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes)
328366
)
@@ -385,6 +423,33 @@ def get_joinedload_options(include_extracted_location_entities: bool = False) ->
385423
]
386424

387425

426+
def get_selectinload_options(include_extracted_location_entities: bool = False) -> List[_AbstractLoad]:
427+
"""
428+
Returns common selectinload options for feeds queries.
429+
:param include_extracted_location_entities: Whether to include extracted location entities.
430+
431+
:return: A list of selectinload options.
432+
"""
433+
# NOTE: For collections we prefer selectinload to avoid row explosion and high memory usage
434+
# during streaming. When callers explicitly join some paths (e.g., most common locations),
435+
# we use contains_eager on that specific path to tell SQLAlchemy the data came from a JOIN.
436+
loaders = []
437+
if include_extracted_location_entities:
438+
loaders.append(contains_eager(Feed.feedosmlocationgroups).joinedload(Feedosmlocationgroup.group))
439+
440+
# collections -> selectinload; scalar relationships can remain joinedload
441+
loaders.extend(
442+
[
443+
selectinload(Feed.locations),
444+
selectinload(Feed.externalids),
445+
selectinload(Feed.feedrelatedlinks),
446+
selectinload(Feed.redirectingids).selectinload(Redirectingid.target),
447+
selectinload(Feed.officialstatushistories),
448+
]
449+
)
450+
return loaders
451+
452+
388453
def get_gbfs_feeds_query(
389454
db_session: Session,
390455
stable_id: Optional[str] = None,

functions-python/export_csv/function_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"name": "export-csv",
33
"description": "Export the DB feed data as a csv file",
44
"entry_point": "export_and_upload_csv",
5-
"timeout": 600,
5+
"timeout": 3600,
66
"memory": "2Gi",
77
"trigger_http": true,
88
"include_folders": ["helpers", "dataset_service"],

functions-python/export_csv/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ python-dotenv==1.0.0
2727

2828
# Other dependencies
2929
natsort
30-
30+
psutil

functions-python/export_csv/src/main.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#
1616
import argparse
1717
import csv
18-
import logging
1918
import os
2019
import re
2120
from typing import Dict, Iterator, Optional
@@ -28,6 +27,7 @@
2827
from google.cloud import storage
2928
from geoalchemy2.shape import to_shape
3029

30+
from shared.helpers.runtime_metrics import track_metrics
3131
from shared.database.database import with_db_session
3232
from shared.helpers.logger import init_logger
3333
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed, Feed
@@ -39,6 +39,8 @@
3939

4040
from shared.database_gen.sqlacodegen_models import Geopolygon
4141

42+
import logging
43+
4244
load_dotenv()
4345
csv_default_file_path = "./output.csv"
4446
init_logger()
@@ -124,6 +126,7 @@ def export_and_upload_csv(_):
124126
return "Export successful"
125127

126128

129+
@track_metrics(metrics=("time", "memory", "cpu"))
127130
def export_csv(csv_file_path: str):
128131
"""
129132
Write feed data to a local CSV file.
@@ -318,7 +321,7 @@ def get_feed_csv_data(
318321
) -> Dict:
319322
"""
320323
This function takes a generic feed and returns a dictionary with the data to be written to the CSV file.
321-
Any specific data (for GTFS or GTFS_RT has to be added after this call.
324+
Any specific data (for GTFS or GTFS_RT) has to be added after this call.
322325
"""
323326

324327
redirect_ids = []
@@ -409,15 +412,24 @@ def get_gtfs_rt_feed_csv_data(
409412
static_references = ""
410413
first_feed_reference = None
411414
if feed.gtfs_feeds:
412-
valid_feed_references = [
413-
feed_reference.stable_id.strip()
414-
for feed_reference in feed.gtfs_feeds
415-
if feed_reference and feed_reference.stable_id
415+
# Prefer active feeds first using a stable sort so original relative order
416+
# within active and inactive groups is preserved.
417+
def _is_active(fr):
418+
try:
419+
return getattr(fr, "status", None) == "active"
420+
except Exception:
421+
return False
422+
423+
# Filter to valid references, then stable sort by active flag (True > False)
424+
valid_refs = [
425+
fr for fr in feed.gtfs_feeds if fr and getattr(fr, "stable_id", None)
416426
]
427+
sorted_refs = sorted(valid_refs, key=_is_active, reverse=True)
428+
429+
valid_feed_references = [fr.stable_id.strip() for fr in sorted_refs]
417430
static_references = "|".join(valid_feed_references)
418-
# If there is more than one GTFS feeds associated with this RT feed (why?)
419-
# We will arbitrarily use the first one in the list for the bounding box.
420-
first_feed_reference = feed.gtfs_feeds[0] if feed.gtfs_feeds else None
431+
# First reference (after sort) will be an active one if any are present
432+
first_feed_reference = sorted_refs[0] if sorted_refs else None
421433
data["static_reference"] = static_references
422434

423435
# For the RT feed, we use the bounding box of the associated GTFS feed, if any.

functions-python/export_csv/tests/conftest.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,30 @@ def populate_database(db_session):
5050
fake = Faker()
5151

5252
feeds = []
53+
54+
# Put the deprecated feeds before the active feeds in the DB so they will be listed first
55+
# in GtfsRealtimeFeed.gtfs_feeds (the RT feed references). This allows testing that active feeds
56+
# are put first in GtfsRealtimeFeed.gtfs_feeds. Admittedly, it's a bit weak but it works for now.
57+
for i in range(2):
58+
feed = Gtfsfeed(
59+
id=fake.uuid4(),
60+
data_type="gtfs",
61+
feed_name=f"deprecated-gtfs-{i} Some fake name",
62+
note=f"deprecated-gtfs-{i} Some fake note",
63+
producer_url=f"https://deprecated-gtfs-{i}_some_fake_producer_url",
64+
authentication_type="0" if (i == 0) else "1",
65+
authentication_info_url=None,
66+
api_key_parameter_name=None,
67+
license_url=f"https://gtfs-{i}_some_fake_license_url",
68+
stable_id=f"deprecated-gtfs-{i}",
69+
status="deprecated",
70+
feed_contact_email=f"deprecated-gtfs-{i}[email protected]",
71+
provider=f"deprecated-gtfs-{i} Some fake company",
72+
operational_status="published",
73+
official=True,
74+
)
75+
db_session.add(feed)
76+
5377
# We create 3 feeds. The first one is active. The third one is inactive and redirected to the first one.
5478
# The second one is active but not redirected.
5579
# First fill the generic parameters
@@ -97,25 +121,6 @@ def populate_database(db_session):
97121
for feed in feeds:
98122
db_session.add(feed)
99123
db_session.flush()
100-
for i in range(2):
101-
feed = Gtfsfeed(
102-
id=fake.uuid4(),
103-
data_type="gtfs",
104-
feed_name=f"gtfs-deprecated-{i} Some fake name",
105-
note=f"gtfs-deprecated-{i} Some fake note",
106-
producer_url=f"https://gtfs-deprecated-{i}_some_fake_producer_url",
107-
authentication_type="0" if (i == 0) else "1",
108-
authentication_info_url=None,
109-
api_key_parameter_name=None,
110-
license_url=f"https://gtfs-{i}_some_fake_license_url",
111-
stable_id=f"gtfs-deprecated-{i}",
112-
status="deprecated",
113-
feed_contact_email=f"gtfs-deprecated-{i}[email protected]",
114-
provider=f"gtfs-deprecated-{i} Some fake company",
115-
operational_status="published",
116-
official=True,
117-
)
118-
db_session.add(feed)
119124

120125
location_entity = Location(id="CA-quebec-montreal")
121126

@@ -273,10 +278,28 @@ def populate_database(db_session):
273278
entitytypes=[vp_entitytype, tu_entitytype] if i == 0 else [vp_entitytype],
274279
operational_status="published",
275280
official=True,
276-
gtfs_feeds=[active_gtfs_feeds[0]] if i == 0 else [],
281+
# Do not attach GTFS feeds at creation; we'll set them in a controlled order below
282+
# gtfs_feeds=[],
277283
)
278284
gtfs_rt_feeds.append(feed)
279285

286+
db_session.add_all(gtfs_rt_feeds)
287+
288+
# --- Attach both a deprecated GTFS feed and an active GTFS feed to the first RT feed
289+
try:
290+
deprecated_feeds = (
291+
db_session.query(Gtfsfeed)
292+
.filter(Gtfsfeed.status == "deprecated")
293+
.order_by(Gtfsfeed.stable_id)
294+
.all()
295+
)
296+
if deprecated_feeds:
297+
gtfs_rt_feeds[0].gtfs_feeds = [deprecated_feeds[0], active_gtfs_feeds[0]]
298+
db_session.flush()
299+
except Exception:
300+
# Best effort in test setup; if it fails the rest of the tests will surface the issue.
301+
pass
302+
280303
# Add redirecting IDs (from main branch logic)
281304
gtfs_rt_feeds[1].redirectingids = [
282305
Redirectingid(

functions-python/export_csv/tests/test_export_csv_main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
# the data is correct.
2525
expected_csv = """
2626
id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,is_official,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment
27+
deprecated-gtfs-0,gtfs,,,,,deprecated-gtfs-0 Some fake company,True,deprecated-gtfs-0 Some fake name,deprecated-gtfs-0 Some fake note,[email protected],,https://deprecated-gtfs-0_some_fake_producer_url,0,,,,https://gtfs-0_some_fake_license_url,,,,,,deprecated,,,
28+
deprecated-gtfs-1,gtfs,,,,,deprecated-gtfs-1 Some fake company,True,deprecated-gtfs-1 Some fake name,deprecated-gtfs-1 Some fake note,[email protected],,https://deprecated-gtfs-1_some_fake_producer_url,1,,,,https://gtfs-1_some_fake_license_url,,,,,,deprecated,,,
2729
gtfs-0,gtfs,,CA,Quebec,Laval,gtfs-0 Some fake company,True,gtfs-0 Some fake name,gtfs-0 Some fake note,[email protected],,https://gtfs-0_some_fake_producer_url,0,,,https://url_prefix/gtfs-0/latest.zip,https://gtfs-0_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,Route Colors|Shapes,,
2830
gtfs-1,gtfs,,CA,Quebec,Montreal,gtfs-1 Some fake company,True,gtfs-1 Some fake name,gtfs-1 Some fake note,[email protected],,https://gtfs-1_some_fake_producer_url,0,,,https://url_prefix/gtfs-1/latest.zip,https://gtfs-1_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,Route Colors|Shapes,,
2931
gtfs-2,gtfs,,,,,gtfs-2 Some fake company,True,gtfs-2 Some fake name,gtfs-2 Some fake note,[email protected],,https://gtfs-2_some_fake_producer_url,0,,,,https://gtfs-2_some_fake_license_url,,,,,,inactive,,gtfs-0,Some redirect comment
30-
gtfs-deprecated-0,gtfs,,,,,gtfs-deprecated-0 Some fake company,True,gtfs-deprecated-0 Some fake name,gtfs-deprecated-0 Some fake note,[email protected],,https://gtfs-deprecated-0_some_fake_producer_url,0,,,,https://gtfs-0_some_fake_license_url,,,,,,deprecated,,,
31-
gtfs-deprecated-1,gtfs,,,,,gtfs-deprecated-1 Some fake company,True,gtfs-deprecated-1 Some fake name,gtfs-deprecated-1 Some fake note,[email protected],,https://gtfs-deprecated-1_some_fake_producer_url,1,,,,https://gtfs-1_some_fake_license_url,,,,,,deprecated,,,
32-
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,True,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,[email protected],gtfs-0,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,,,
32+
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,True,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,[email protected],gtfs-0|deprecated-gtfs-0,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,-9.0,9.0,-18.0,18.0,2025-01-12 00:00:00+00:00,active,,,
3333
gtfs-rt-1,gtfs_rt,vp,,,,gtfs-rt-1 Some fake company,True,gtfs-rt-1 Some fake name,gtfs-rt-1 Some fake note,[email protected],,https://gtfs-rt-1_some_fake_producer_url,1,https://gtfs-rt-1_some_fake_authentication_info_url,gtfs-rt-1_fake_api_key_parameter_name,,https://gtfs-rt-1_some_fake_license_url,,,,,,inactive,,gtfs-rt-0|gtfs-rt-2,comment 1|comment 2
3434
gtfs-rt-2,gtfs_rt,vp,,,,gtfs-rt-2 Some fake company,True,gtfs-rt-2 Some fake name,gtfs-rt-2 Some fake note,[email protected],,https://gtfs-rt-2_some_fake_producer_url,2,https://gtfs-rt-2_some_fake_authentication_info_url,gtfs-rt-2_fake_api_key_parameter_name,,https://gtfs-rt-2_some_fake_license_url,,,,,,active,,,
3535
""" # noqa

0 commit comments

Comments
 (0)