Skip to content

Commit 2a0d914

Browse files
authored
fix: fetch feeds in batches for export_csv
1 parent 68f77fa commit 2a0d914

File tree

6 files changed: +91 additions, -70 deletions

api/src/shared/common/db_utils.py

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from typing import Iterator
2+
13
from geoalchemy2 import WKTElement
4+
from sqlalchemy import or_
25
from sqlalchemy import select
36
from sqlalchemy.orm import joinedload, Session
47
from sqlalchemy.orm.query import Query
@@ -14,15 +17,11 @@
1417
Entitytype,
1518
Redirectingid,
1619
)
17-
1820
from shared.feed_filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter
1921
from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter
20-
2122
from .entity_type_enum import EntityType
22-
23-
from sqlalchemy import or_
24-
2523
from .error_handling import raise_internal_http_validation_error, invalid_bounding_coordinates, invalid_bounding_method
24+
from .iter_utils import batched
2625

2726

2827
def get_gtfs_feeds_query(
@@ -75,28 +74,39 @@ def get_gtfs_feeds_query(
7574
return feed_query
7675

7776

78-
def get_all_gtfs_feeds_query(
77+
def get_all_gtfs_feeds(
78+
db_session: Session,
7979
include_wip: bool = False,
80-
db_session: Session = None,
81-
) -> Query[any]:
82-
"""Get the DB query to use to retrieve all the GTFS feeds, filtering out the WIP if needed"""
83-
84-
feed_query = db_session.query(Gtfsfeed)
85-
80+
batch_size: int = 250,
81+
) -> Iterator[Gtfsfeed]:
82+
"""
83+
Fetch all GTFS feeds.
84+
85+
@param db_session: The database session.
86+
@param include_wip: Whether to include or exclude WIP feeds.
87+
@param batch_size: The number of feeds to fetch from the database at a time.
88+
A lower value means less memory but more queries.
89+
90+
@return: The GTFS feeds in an iterator.
91+
"""
92+
feed_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
8693
if not include_wip:
87-
feed_query = feed_query.filter(
88-
or_(Gtfsfeed.operational_status == None, Gtfsfeed.operational_status != "wip") # noqa: E711
94+
feed_query = feed_query.filter(Gtfsfeed.operational_status.is_distinct_from("wip"))
95+
96+
for batch in batched(feed_query, batch_size):
97+
stable_ids = (f.stable_id for f in batch)
98+
yield from (
99+
db_session.query(Gtfsfeed)
100+
.filter(Gtfsfeed.stable_id.in_(stable_ids))
101+
.options(
102+
joinedload(Gtfsfeed.gtfsdatasets)
103+
.joinedload(Gtfsdataset.validation_reports)
104+
.joinedload(Validationreport.features),
105+
*get_joinedload_options(),
106+
)
107+
.order_by(Gtfsfeed.stable_id)
89108
)
90109

91-
feed_query = feed_query.options(
92-
joinedload(Gtfsfeed.gtfsdatasets)
93-
.joinedload(Gtfsdataset.validation_reports)
94-
.joinedload(Validationreport.features),
95-
*get_joinedload_options(),
96-
).order_by(Gtfsfeed.stable_id)
97-
98-
return feed_query
99-
100110

101111
def get_gtfs_rt_feeds_query(
102112
limit: int | None,
@@ -161,29 +171,38 @@ def get_gtfs_rt_feeds_query(
161171
return feed_query
162172

163173

164-
def get_all_gtfs_rt_feeds_query(
174+
def get_all_gtfs_rt_feeds(
175+
db_session: Session,
165176
include_wip: bool = False,
166-
db_session: Session = None,
167-
) -> Query:
168-
"""Get the DB query to use to retrieve all the GTFS rt feeds, filtering out the WIP if needed"""
169-
feed_query = db_session.query(Gtfsrealtimefeed)
170-
177+
batch_size: int = 250,
178+
) -> Iterator[Gtfsrealtimefeed]:
179+
"""
180+
Fetch all GTFS realtime feeds.
181+
182+
@param db_session: The database session.
183+
@param include_wip: Whether to include or exclude WIP feeds.
184+
@param batch_size: The number of feeds to fetch from the database at a time.
185+
A lower value means less memory but more queries.
186+
187+
@return: The GTFS realtime feeds in an iterator.
188+
"""
189+
feed_query = db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
171190
if not include_wip:
172-
feed_query = feed_query.filter(
173-
or_(
174-
Gtfsrealtimefeed.operational_status == None, # noqa: E711
175-
Gtfsrealtimefeed.operational_status != "wip",
191+
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status.is_distinct_from("wip"))
192+
193+
for batch in batched(feed_query, batch_size):
194+
stable_ids = (f.stable_id for f in batch)
195+
yield from (
196+
db_session.query(Gtfsrealtimefeed)
197+
.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
198+
.options(
199+
joinedload(Gtfsrealtimefeed.entitytypes),
200+
joinedload(Gtfsrealtimefeed.gtfs_feeds),
201+
*get_joinedload_options(),
176202
)
203+
.order_by(Gtfsfeed.stable_id)
177204
)
178205

179-
feed_query = feed_query.options(
180-
joinedload(Gtfsrealtimefeed.entitytypes),
181-
joinedload(Gtfsrealtimefeed.gtfs_feeds),
182-
*get_joinedload_options(),
183-
).order_by(Gtfsfeed.stable_id)
184-
185-
return feed_query
186-
187206

188207
def apply_bounding_filtering(
189208
query: Query,
api/src/shared/common/iter_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from itertools import islice
2+
3+
4+
def batched(iterable, n):
5+
"""
6+
Batch an iterable into tuples of length `n`. The last batch may be shorter.
7+
8+
Based on the implementation in more-itertools and will be built-in once we
9+
switch to Python 3.12+.
10+
"""
11+
it = iter(iterable)
12+
while batch := tuple(islice(it, n)):
13+
yield batch

functions-python/export_csv/src/main.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from shared.helpers.logger import Logger
3131
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
32-
from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query
32+
from shared.common.db_utils import get_all_gtfs_rt_feeds, get_all_gtfs_feeds
3333

3434
from shared.helpers.database import Database
3535

@@ -114,33 +114,19 @@ def fetch_feeds() -> Iterator[Dict]:
114114
logging.info(f"Using database {db.database_url}")
115115
try:
116116
with db.start_db_session() as session:
117-
gtfs_feeds_query = get_all_gtfs_feeds_query(
118-
include_wip=False,
119-
db_session=session,
120-
)
121-
122-
gtfs_feeds = gtfs_feeds_query.all()
123-
124-
logging.info(f"Retrieved {len(gtfs_feeds)} GTFS feeds.")
125-
126-
gtfs_rt_feeds_query = get_all_gtfs_rt_feeds_query(
127-
include_wip=False,
128-
db_session=session,
129-
)
130-
131-
gtfs_rt_feeds = gtfs_rt_feeds_query.all()
132-
133-
logging.info(f"Retrieved {len(gtfs_rt_feeds)} GTFS realtime feeds.")
134-
135-
for feed in gtfs_feeds:
117+
feed_count = 0
118+
for feed in get_all_gtfs_feeds(session, include_wip=False):
136119
yield get_feed_csv_data(feed)
120+
feed_count += 1
137121

138-
logging.info(f"Processed {len(gtfs_feeds)} GTFS feeds.")
122+
logging.info(f"Processed {feed_count} GTFS feeds.")
139123

140-
for feed in gtfs_rt_feeds:
124+
rt_feed_count = 0
125+
for feed in get_all_gtfs_rt_feeds(session, include_wip=False):
141126
yield get_gtfs_rt_feed_csv_data(feed)
127+
rt_feed_count += 1
142128

143-
logging.info(f"Processed {len(gtfs_rt_feeds)} GTFS realtime feeds.")
129+
logging.info(f"Processed {rt_feed_count} GTFS realtime feeds.")
144130

145131
except Exception as error:
146132
logging.error(f"Error retrieving feeds: {error}")

functions-python/export_csv/tests/conftest.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def populate_database():
184184
session.add(tu_entitytype)
185185

186186
# GTFS Realtime feeds
187-
for i in range(3):
188-
gtfs_rt_feed = Gtfsrealtimefeed(
187+
gtfs_rt_feeds = [
188+
Gtfsrealtimefeed(
189189
id=fake.uuid4(),
190190
data_type="gtfs_rt",
191191
feed_name=f"gtfs-rt-{i} Some fake name",
@@ -201,7 +201,10 @@ def populate_database():
201201
provider=f"gtfs-rt-{i} Some fake company",
202202
entitytypes=[vp_entitytype, tu_entitytype] if (i == 0) else [vp_entitytype],
203203
)
204-
session.add(gtfs_rt_feed)
204+
for i in range(3)
205+
]
206+
gtfs_rt_feeds[0].gtfs_feeds.append(active_gtfs_feeds[0])
207+
session.add_all(gtfs_rt_feeds)
205208

206209
session.commit()
207210

functions-python/export_csv/tests/test_export_csv_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
gtfs-2,gtfs,,,,,gtfs-2 Some fake company,gtfs-2 Some fake name,gtfs-2 Some fake note,[email protected],,https://gtfs-2_some_fake_producer_url,0,,,,https://gtfs-2_some_fake_license_url,,,,,,inactive,,gtfs-0,Some redirect comment
3030
gtfs-deprecated-0,gtfs,,,,,gtfs-deprecated-0 Some fake company,gtfs-deprecated-0 Some fake name,gtfs-deprecated-0 Some fake note,[email protected],,https://gtfs-deprecated-0_some_fake_producer_url,0,,,,https://gtfs-0_some_fake_license_url,,,,,,deprecated,,,
3131
gtfs-deprecated-1,gtfs,,,,,gtfs-deprecated-1 Some fake company,gtfs-deprecated-1 Some fake name,gtfs-deprecated-1 Some fake note,[email protected],,https://gtfs-deprecated-1_some_fake_producer_url,1,,,,https://gtfs-1_some_fake_license_url,,,,,,deprecated,,,
32-
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,[email protected],,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,,,,,,,,,
32+
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,[email protected],gtfs-0,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,,,,,,,,,
3333
gtfs-rt-1,gtfs_rt,vp,,,,gtfs-rt-1 Some fake company,gtfs-rt-1 Some fake name,gtfs-rt-1 Some fake note,[email protected],,https://gtfs-rt-1_some_fake_producer_url,1,https://gtfs-rt-1_some_fake_authentication_info_url,gtfs-rt-1_fake_api_key_parameter_name,,https://gtfs-rt-1_some_fake_license_url,,,,,,,,,
3434
gtfs-rt-2,gtfs_rt,vp,,,,gtfs-rt-2 Some fake company,gtfs-rt-2 Some fake name,gtfs-rt-2 Some fake note,[email protected],,https://gtfs-rt-2_some_fake_producer_url,2,https://gtfs-rt-2_some_fake_authentication_info_url,gtfs-rt-2_fake_api_key_parameter_name,,https://gtfs-rt-2_some_fake_license_url,,,,,,,,,
3535
""" # noqa

functions-python/helpers/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import logging
1919
import os
2020
import threading
21-
from typing import Optional
21+
from typing import Optional, ContextManager
2222

2323
from sqlalchemy import create_engine, text, event, Engine
2424
from sqlalchemy.orm import sessionmaker, Session, mapper, class_mapper
@@ -159,7 +159,7 @@ def _get_session(self, echo: bool) -> "sessionmaker[Session]":
159159
return self._Sessions[echo]
160160

161161
@contextmanager
162-
def start_db_session(self, echo: bool = True):
162+
def start_db_session(self, echo: bool = True) -> ContextManager[Session]:
163163
"""
164164
Context manager to start a database session with optional echo.
165165

0 commit comments

Comments
 (0)